{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 朴素贝叶斯分类器——垃圾邮件分类\n",
    "\n",
    "1. 收集数据：提供文本文件\n",
    "2. 准备数据：将文本文件解析成词向量\n",
    "3. 分析数据：检查词条确保解析正确性\n",
    "4. 训练算法：使用trainBayes方法\n",
    "5. 测试算法：使用classifyNB，构建一个新的测试函数计算文档的错误率\n",
    "6. 使用算法：构建一个完整的程序对一组文档进行分类，将错分的文档输出"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据：切分文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import random\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize raw text\n",
    "def textParse(text):\n",
    "    \"\"\"Split text on runs of non-word characters and return the lowercased pieces longer than two characters.\"\"\"\n",
    "    tokens = []\n",
    "    for piece in re.split(r'\\W+', text):\n",
    "        # Pieces of length 0-2 (including the empty strings re.split can yield) are dropped.\n",
    "        if len(piece) <= 2:\n",
    "            continue\n",
    "        tokens.append(piece.lower())\n",
    "    return tokens\n",
    "# Build the vocabulary\n",
    "def createVocabList(dataSet):\n",
    "    \"\"\"Return every distinct token found in dataSet as a sorted list.\n",
    "\n",
    "    dataSet is an iterable of token sequences. Sorting gives each word the same\n",
    "    index on every run, instead of an order that depends on set iteration.\n",
    "    \"\"\"\n",
    "    vocabSet = set()\n",
    "    for doc in dataSet:\n",
    "        # In-place update avoids building a brand-new set for every document.\n",
    "        vocabSet.update(doc)\n",
    "    return sorted(vocabSet)\n",
    "# Convert a document into a set-of-words (binary presence) vector\n",
    "def setOfWordVectors(vocabList, inputSet):\n",
    "    \"\"\"Return a 0/1 list aligned with vocabList marking which vocabulary words occur in inputSet.\n",
    "\n",
    "    Words of inputSet that are missing from vocabList are reported with a printed\n",
    "    message and otherwise ignored.\n",
    "    \"\"\"\n",
    "    # Build the word -> first index lookup once, so each token costs one dict lookup\n",
    "    # instead of two linear scans of vocabList.\n",
    "    position = {}\n",
    "    for idx, vocabWord in enumerate(vocabList):\n",
    "        position.setdefault(vocabWord, idx)\n",
    "    returnVec = [0] * len(vocabList)\n",
    "    for word in inputSet:\n",
    "        idx = position.get(word)\n",
    "        if idx is None:\n",
    "            print(\"单词：%s不在词表中！\" % word)\n",
    "        else:\n",
    "            returnVec[idx] = 1\n",
    "    return returnVec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练算法：从词向量计算概率\n",
    "其实就是计算先验概率（类别概率）和条件概率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainBayes(trainMatrix, trainCategory):\n",
    "    \"\"\"Estimate the class-1 prior and the per-class log conditional word probabilities.\n",
    "\n",
    "    Rows whose label equals 1 feed the class-1 counts; every other row feeds class 0.\n",
    "    Word counts start at 1 and per-class totals at 2.0, so no ratio is zero before the log.\n",
    "\n",
    "    Returns a tuple (class-0 log probabilities, class-1 log probabilities, prior),\n",
    "    where prior is sum(trainCategory) divided by the number of rows.\n",
    "    \"\"\"\n",
    "    docCount = len(trainMatrix)\n",
    "    vocabSize = len(trainMatrix[0])\n",
    "    prior1 = sum(trainCategory) / float(docCount)\n",
    "    # Index 0 holds the class-0 accumulators, index 1 the class-1 accumulators.\n",
    "    wordCounts = [np.ones(vocabSize), np.ones(vocabSize)]\n",
    "    wordTotals = [2.0, 2.0]\n",
    "    for docIdx in range(docCount):\n",
    "        label = 1 if trainCategory[docIdx] == 1 else 0\n",
    "        wordCounts[label] += trainMatrix[docIdx]\n",
    "        wordTotals[label] += sum(trainMatrix[docIdx])\n",
    "    logCond0 = np.log(wordCounts[0] / wordTotals[0])\n",
    "    logCond1 = np.log(wordCounts[1] / wordTotals[1])\n",
    "    return logCond0, logCond1, prior1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):\n",
    "    \"\"\"Return 1 when the class-1 log score is strictly greater than the class-0 log score, otherwise 0.\n",
    "\n",
    "    Each score is the sum of vec2Classify multiplied element-wise by that class's\n",
    "    log-probability vector, plus the log of that class's prior (pClass1 or 1 - pClass1).\n",
    "    \"\"\"\n",
    "    logPrior1 = np.log(pClass1)\n",
    "    logPrior0 = np.log(1.0 - pClass1)\n",
    "    score1 = sum(vec2Classify * p1Vec) + logPrior1\n",
    "    score0 = sum(vec2Classify * p0Vec) + logPrior0\n",
    "    # Ties go to class 0 because the comparison is strict.\n",
    "    return 1 if score1 > score0 else 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试算法：使用朴素贝叶斯进行留存交叉验证（hold-out）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _readTokens(path):\n",
    "    \"\"\"Read one e-mail file and return the tokens produced by textParse.\"\"\"\n",
    "    # The context manager closes the file handle even if reading or parsing fails.\n",
    "    with open(path) as emailFile:\n",
    "        return textParse(emailFile.read())\n",
    "\n",
    "\n",
    "def spamTest(dataDir='email', numPerClass=25, numTest=10):\n",
    "    \"\"\"Train on a random subset of the e-mails, print the hold-out error rate, then classify a sample sentence.\n",
    "\n",
    "    Class 1 is spam and class 0 is ham, so the prior returned by trainBayes for\n",
    "    class 1 really is the spam prior.\n",
    "\n",
    "    Parameters:\n",
    "        dataDir: directory holding the 'ham' and 'spam' sub-directories of numbered .txt files.\n",
    "        numPerClass: how many files (1.txt .. N.txt) to read from each sub-directory.\n",
    "        numTest: how many documents to hold out for testing.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if numTest is not between 1 and the number of documents minus 1.\n",
    "    \"\"\"\n",
    "    docList, classList = [], []  # one label per document\n",
    "    for fileNo in range(1, numPerClass + 1):\n",
    "        docList.append(_readTokens('{}/ham/{}.txt'.format(dataDir, fileNo)))\n",
    "        classList.append(0)\n",
    "        docList.append(_readTokens('{}/spam/{}.txt'.format(dataDir, fileNo)))\n",
    "        classList.append(1)\n",
    "    numDocs = len(docList)\n",
    "    if not 0 < numTest < numDocs:\n",
    "        raise ValueError('numTest must be between 1 and {}'.format(numDocs - 1))\n",
    "    vocabList = createVocabList(docList)\n",
    "    # Draw the hold-out indices without replacement; everything else is training data.\n",
    "    testSet = random.sample(range(numDocs), numTest)\n",
    "    heldOut = set(testSet)\n",
    "    trainSet = [idx for idx in range(numDocs) if idx not in heldOut]\n",
    "    trainMat = [setOfWordVectors(vocabList, docList[idx]) for idx in trainSet]\n",
    "    trainClasses = [classList[idx] for idx in trainSet]\n",
    "    p0V, p1V, pSpam = trainBayes(np.array(trainMat), np.array(trainClasses))\n",
    "    errorCount = 0\n",
    "    for docIndex in testSet:\n",
    "        wordVector = setOfWordVectors(vocabList, docList[docIndex])\n",
    "        if classifyNB(np.array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:\n",
    "            errorCount += 1\n",
    "    print(\"错误率是：{:.2f}\".format(float(errorCount)/len(testSet)))\n",
    "    # Tokenize the sample exactly like the training e-mails so case and short words match the vocabulary.\n",
    "    sampleTokens = textParse(\"used to treat moderate to moderately SeverePain\")\n",
    "    predict = classifyNB(np.array(setOfWordVectors(vocabList, sampleTokens)), p0V, p1V, pSpam)\n",
    "    print(predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "错误率是：0.00\n",
      "单词：to不在词表中！\n",
      "单词：to不在词表中！\n",
      "单词：SeverePain不在词表中！\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# Run the evaluation: spamTest prints the test error rate and the prediction for its sample sentence.\n",
    "spamTest()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
